import os
import h5py
import pickle
from tqdm import tqdm
import numpy as np
import ujson as json
import jax.numpy as jnp

def new_get_trj_idx_reward_model(relabeled_dataset, env, terminate_on_end=False, top_k=700, **kwargs):
    """Return index spans of the highest-return trajectories under relabeled rewards.

    Trajectory boundaries follow the same rules as ``new_get_trj_idx`` (D4RL
    ``qlearning_dataset`` style): a transition ends an episode on a terminal
    flag (or, for maze envs, a change of ``infos/goal``) or on a timeout /
    ``env._max_episode_steps`` boundary.  When ``terminate_on_end`` is False,
    timeout transitions are skipped and do not advance the compacted index.
    Episode returns are summed from ``relabeled_dataset.rewards`` indexed by
    the raw transition index.

    Args:
        relabeled_dataset: object with a ``rewards`` array (presumably 1-D,
            aligned with the raw dataset -- TODO confirm with callers).
        env: environment; ``env.get_dataset()`` is used when available,
            otherwise ``kwargs['dataset']``.
        terminate_on_end: keep timeout transitions instead of skipping them.
        top_k: only episodes whose return is at least the ``top_k``-th largest
            return are kept.  Clamped to the number of episodes; a
            non-positive value keeps every episode.
        **kwargs: ``dataset`` must be supplied when ``env`` has no
            ``get_dataset``.

    Returns:
        list of ``[start_idx, end_idx]`` (inclusive, compacted indices) for the
        kept episodes, followed unconditionally by the trailing, unfinished
        span.
    """
    if not hasattr(env, 'get_dataset'):
        dataset = kwargs['dataset']
    else:
        dataset = env.get_dataset()
    num_steps = dataset['rewards'].shape[0]
    rewards = relabeled_dataset.rewards

    use_timeouts = 'timeouts' in dataset
    is_maze = bool(env.spec) and 'maze' in env.spec.id
    goals = dataset['infos/goal'] if is_maze else None
    max_steps = None if use_timeouts else env._max_episode_steps

    # One scan collects (start, end, return) for every completed episode.  Using
    # a single scan guarantees the return ranking and the emitted spans are
    # computed on exactly the same episode boundaries (the old separate first
    # pass never advanced its step counter, so it disagreed without timeouts).
    episodes = []
    episode_step = 0
    start_idx, data_idx = 0, 0
    episode_reward = 0
    for i in range(num_steps - 1):
        episode_reward += rewards[i]

        if is_maze:
            # Any coordinate change marks a new goal; summing the difference
            # (as before) missed changes whose deltas sum to <= 0.
            done_bool = bool(np.any(np.asarray(goals[i + 1]) != np.asarray(goals[i])))
        else:
            done_bool = bool(dataset['terminals'][i])
        if use_timeouts:
            final_timestep = bool(dataset['timeouts'][i])
        else:
            final_timestep = episode_step == max_steps - 1

        if (not terminate_on_end) and final_timestep:
            # Skip this transition and don't apply terminals on the last step of an episode
            episodes.append((start_idx, data_idx - 1, episode_reward))
            episode_step = 0
            episode_reward = 0
            start_idx = data_idx
            continue
        if done_bool or final_timestep:
            episodes.append((start_idx, data_idx, episode_reward))
            episode_step = 0
            episode_reward = 0
            start_idx = data_idx + 1
        episode_step += 1
        data_idx += 1

    episode_returns = np.array([ret for _, _, ret in episodes])
    k = min(top_k, len(episode_returns))
    if k > 0:
        # k-th largest return == min of the top-k returns.
        margin_episode_return = np.partition(episode_returns, -k)[-k]
    else:
        margin_episode_return = -np.inf

    trj_idx_list = [[start, end] for start, end, ret in episodes if ret >= margin_episode_return]
    trj_idx_list.append([start_idx, data_idx])
    return trj_idx_list


def new_get_trj_idx(env, terminate_on_end=False, **kwargs):
    """Split a D4RL-style dataset into trajectories and return their index spans.

    A transition ends an episode on a terminal flag (or, for maze envs, a change
    of ``infos/goal``) or on a timeout / ``env._max_episode_steps`` boundary.
    When ``terminate_on_end`` is False, timeout transitions are skipped: they do
    not advance the compacted index ``data_idx`` (mirroring D4RL's
    ``qlearning_dataset``), so the spans index the compacted transition array.

    Args:
        env: environment; ``env.get_dataset()`` is used when available,
            otherwise ``kwargs['dataset']``.
        terminate_on_end: keep timeout transitions instead of skipping them.
        **kwargs: ``dataset`` must be supplied when ``env`` has no
            ``get_dataset``.

    Returns:
        list of ``[start_idx, end_idx]`` (inclusive) spans, one per completed
        episode, followed by the trailing, unfinished span.
    """
    if not hasattr(env, 'get_dataset'):
        dataset = kwargs['dataset']
    else:
        dataset = env.get_dataset()
    num_steps = dataset['rewards'].shape[0]

    # Loop-invariant environment/dataset properties.
    use_timeouts = 'timeouts' in dataset
    is_maze = bool(env.spec) and 'maze' in env.spec.id
    goals = dataset['infos/goal'] if is_maze else None
    max_steps = None if use_timeouts else env._max_episode_steps

    episode_step = 0
    start_idx, data_idx = 0, 0
    trj_idx_list = []
    for i in range(num_steps - 1):
        if is_maze:
            # Any coordinate change marks a new goal; summing the difference
            # (as before) missed changes whose deltas sum to <= 0.
            done_bool = bool(np.any(np.asarray(goals[i + 1]) != np.asarray(goals[i])))
        else:
            done_bool = bool(dataset['terminals'][i])
        if use_timeouts:
            final_timestep = bool(dataset['timeouts'][i])
        else:
            final_timestep = episode_step == max_steps - 1

        if (not terminate_on_end) and final_timestep:
            # Skip this transition and don't apply terminals on the last step of an episode
            episode_step = 0
            trj_idx_list.append([start_idx, data_idx - 1])
            start_idx = data_idx
            continue
        if done_bool or final_timestep:
            episode_step = 0
            trj_idx_list.append([start_idx, data_idx])
            start_idx = data_idx + 1

        episode_step += 1
        data_idx += 1

    trj_idx_list.append([start_idx, data_idx])
    return trj_idx_list


def select_query(critic, seg_obs_1, seg_obs_2, seq_act_1, seq_act_2, query_index, num_select=2):
    """Return indices of the candidate queries with the largest critic disagreement.

    ``critic(obs, act)`` is called on the flattened ``(batch * query_len, dim)``
    inputs and must return a pair of per-step estimates ``(qa, qb)``.  For each
    candidate pair the score is
    ``|sum_t [(qa - qb)(segment 1) - (qa - qb)(segment 2)]|`` and the
    ``num_select`` highest-scoring indices are returned in ascending score order.

    Args:
        critic: callable returning two arrays of ``batch * query_len`` elements.
        seg_obs_1, seg_obs_2: observation segments, ``(batch, query_len, obs_dim)``.
        seq_act_1, seq_act_2: action segments, ``(batch, query_len, act_dim)``.
        query_index: unused; kept for interface compatibility.
        num_select: number of indices to return (default 2, the previous fixed
            value).  A non-positive value returns an empty index array.
    """
    batch_num, query_len, obs_dim = seg_obs_1.shape[0], seg_obs_1.shape[1], seg_obs_1.shape[2]
    act_dim = seq_act_1.shape[2]

    q1, q2 = critic(seg_obs_1.reshape(-1, obs_dim), seq_act_1.reshape(-1, act_dim))
    q3, q4 = critic(seg_obs_2.reshape(-1, obs_dim), seq_act_2.reshape(-1, act_dim))

    # Convert once so the ranking below is plain NumPy even if the critic
    # returns JAX (or other array-like) outputs.
    head_gap_1 = np.asarray(q1).reshape(batch_num, query_len) - np.asarray(q2).reshape(batch_num, query_len)
    head_gap_2 = np.asarray(q3).reshape(batch_num, query_len) - np.asarray(q4).reshape(batch_num, query_len)
    scores = np.abs(np.sum(head_gap_1 - head_gap_2, axis=-1))

    order = scores.argsort()
    if num_select <= 0:
        # order[-0:] would return everything; return an empty index array instead.
        return order[:0]
    return order[-num_select:]


def init_pref_dataset(dataset, num_query, len_query):
    """Allocate zero-filled buffers for ``num_query`` preference queries.

    Each buffer exists twice (suffix ``1`` / ``2``, one per segment of a pair):
    ``r`` rewards, ``o`` observations, ``no`` next observations, ``a`` actions,
    ``tt`` int32 timesteps, ``s`` start indices, ``t`` in-trajectory offsets and
    ``traj`` per-step trajectory returns.  Observation/action widths are taken
    from the last axis of ``dataset["observations"]`` / ``dataset["actions"]``.
    """
    obs_width = dataset["observations"].shape[-1]
    act_width = dataset["actions"].shape[-1]
    per_step = (num_query, len_query)
    per_query = (num_query,)

    layout = [
        ('r', per_step, float),
        ('o', per_step + (obs_width,), float),
        ('no', per_step + (obs_width,), float),
        ('a', per_step + (act_width,), float),
        ('tt', per_step, np.int32),
        ('s', per_query, float),
        ('t', per_query, float),
        ('traj', per_step, float),
    ]
    return {
        f'{prefix}{segment}': np.zeros(shape, dtype=dtype)
        for prefix, shape, dtype in layout
        for segment in (1, 2)
    }


def construct_query_dataset(dataset, pref_dataset_dict, prepare_num_query, trj_idx_list, labeler_info, len_query, skip_flag, trj_len_list):
    """Fill ``pref_dataset_dict`` with randomly sampled pairs of length-``len_query`` segments.

    For each of ``prepare_num_query`` queries, two trajectories are drawn from
    every trajectory except the trailing one whose ``labeler_info`` entry is
    falsy.  The second draw must share the first draw's ``labeler_info`` value.
    Only trajectories strictly longer than ``len_query`` are accepted.  A
    random window of each trajectory is copied into the ``*1`` / ``*2`` buffers
    of ``pref_dataset_dict`` (mutated in place).

    ``skip_flag`` rejects a second segment whose reward sum equals the first
    segment's reward sum (of the same query):
        1 -> for every query,
        2 -> for the first 50% of queries only,
        3 -> for the first 20% of queries only,
        anything else -> never.
    NOTE: rejection resamples indefinitely, so a dataset where ties are
    unavoidable will loop forever.

    Returns:
        copies of (r1, r2, o1, o2, no1, no2, a1, a2, tt1, tt2, s1, s2, traj1, traj2).
    """
    total_reward_seq_1, total_reward_seq_2 = pref_dataset_dict['r1'], pref_dataset_dict['r2']
    total_obs_seq_1, total_obs_seq_2 = pref_dataset_dict['o1'], pref_dataset_dict['o2']
    total_next_obs_seq_1, total_next_obs_seq_2 = pref_dataset_dict['no1'], pref_dataset_dict['no2']
    total_act_seq_1, total_act_seq_2 = pref_dataset_dict['a1'], pref_dataset_dict['a2']
    total_timestep_1, total_timestep_2 = pref_dataset_dict['tt1'], pref_dataset_dict['tt2']
    start_indices_1, start_indices_2 = pref_dataset_dict['s1'], pref_dataset_dict['s2']
    time_indices_1, time_indices_2 = pref_dataset_dict['t1'], pref_dataset_dict['t2']
    traj_return_1, traj_return_2 = pref_dataset_dict['traj1'], pref_dataset_dict['traj2']

    targets_1 = (start_indices_1, time_indices_1, total_reward_seq_1, total_obs_seq_1,
                 total_next_obs_seq_1, total_act_seq_1, total_timestep_1, traj_return_1)
    targets_2 = (start_indices_2, time_indices_2, total_reward_seq_2, total_obs_seq_2,
                 total_next_obs_seq_2, total_act_seq_2, total_timestep_2, traj_return_2)

    # Number of leading queries for which equal-return pairs are rejected.
    skip_until = {
        1: prepare_num_query,             # always reject ties
        2: int(0.5 * prepare_num_query),  # reject ties in the first 50% of queries
        3: int(0.2 * prepare_num_query),  # reject ties in the first 20% of queries
    }.get(skip_flag, 0)

    # The sampling pool does not change inside the loop: every trajectory but the
    # trailing (unfinished) one, restricted to falsy labeler_info entries.
    candidate_trjs = np.arange(len(trj_idx_list) - 1)[np.logical_not(labeler_info)]

    for query_count in tqdm(range(prepare_num_query), desc="get queries"):
        temp_count = 0
        labeler = -1
        while temp_count < 2:
            trj_idx = np.random.choice(candidate_trjs)
            len_trj = trj_len_list[trj_idx]
            if len_trj <= len_query or (temp_count != 0 and labeler_info[trj_idx] != labeler):
                continue

            labeler = labeler_info[trj_idx]
            time_idx = np.random.choice(len_trj - len_query + 1)
            start_idx = trj_idx_list[trj_idx][0] + time_idx
            end_idx = start_idx + len_query

            assert end_idx <= trj_idx_list[trj_idx][1] + 1

            reward_seq = dataset['rewards'][start_idx:end_idx]
            obs_seq = dataset['observations'][start_idx:end_idx]
            next_obs_seq = dataset['next_observations'][start_idx:end_idx]
            act_seq = dataset['actions'][start_idx:end_idx]
            traj_return_seq = dataset['traj return'][start_idx:end_idx]
            timestep_seq = np.arange(1, len_query + 1)

            # Compare against THIS query's first segment.  The previous code used
            # row [-1] (the last, still-zero preallocated row), which rejected
            # zero-return segments instead of ties.
            if (temp_count == 1 and query_count < skip_until
                    and np.sum(total_reward_seq_1[query_count]) == np.sum(reward_seq)):
                continue

            values = (start_idx, time_idx, reward_seq, obs_seq,
                      next_obs_seq, act_seq, timestep_seq, traj_return_seq)
            for target, value in zip(targets_1 if temp_count == 0 else targets_2, values):
                target[query_count] = value

            temp_count += 1
    return (total_reward_seq_1.copy(), total_reward_seq_2.copy(),
            total_obs_seq_1.copy(), total_obs_seq_2.copy(),
            total_next_obs_seq_1.copy(), total_next_obs_seq_2.copy(),
            total_act_seq_1.copy(), total_act_seq_2.copy(),
            total_timestep_1.copy(), total_timestep_2.copy(),
            start_indices_1.copy(), start_indices_2.copy(),
            traj_return_1.copy(), traj_return_2.copy())


# def split_into_trajectories(observations, actions, rewards, dones_float,
#                             next_observations):
#     trajs = [[]]

#     for i in range(len(observations)):
#         trajs[-1].append((observations[i], actions[i], rewards[i],
#                         dones_float[i], next_observations[i]))
#         if dones_float[i] == 1.0 and i + 1 < len(observations):
#             trajs.append([])

#     return trajs


def split_into_trajectories(observations, actions, rewards, masks, dones_float,
                            next_observations):
    """Group transitions into trajectories delimited by ``dones_float == 1.0``.

    Returns ``(trajs, traj_return)``.  ``trajs`` is a list of lists of
    ``(obs, action, reward, mask, done, next_obs)`` tuples.  A new (possibly
    empty) list is opened after every done, so when the final transition is a
    done the last element of ``trajs`` is empty.  ``traj_return`` holds the
    undiscounted reward sum of each done-terminated trajectory, in order; a
    trailing segment that never hits a done gets no entry.
    """
    trajectories = [[]]
    returns = []
    running_return = 0
    for step in tqdm(range(len(observations)), desc="split"):
        transition = (observations[step], actions[step], rewards[step], masks[step],
                      dones_float[step], next_observations[step])
        trajectories[-1].append(transition)
        running_return += transition[2]
        if transition[4] != 1.0:
            continue
        # Every done closes the current trajectory (including the very last row).
        trajectories.append([])
        returns.append(running_return)
        running_return = 0
    return trajectories, returns


# def compute_mc_target(dataset):
#     dones_float = np.zeros_like(dataset['rewards'])
#     for i in range(len(dones_float) - 1):
#         if np.linalg.norm(dataset['observations'][i + 1] -
#                             dataset['next_observations'][i]
#                             ) > 1e-6 or dataset['terminals'][i] == 1.0:
#             dones_float[i] = 1
#         else:
#             dones_float[i] = 0
    
#     dones_float[-1] = 1
#     trajs = split_into_trajectories(dataset['observations'], dataset['actions'], 
#                     dataset['rewards'], dones_float, dataset['next_observations'])
#     mc_targets = []
#     for index in range(len(trajs)):
#         traj_length = len(trajs[index])
#         mc_target = []
#         rewards = []
#         reward = 0
#         for step in range(traj_length):
#             reward = trajs[index][-(step+1)][2] + 0.99 * reward
#             mc_target.append(reward)
#             rewards.append(trajs[index][-(step+1)][2])
#         # print(mc_target[::-1], '\n')
#         # print(np.array(rewards[::-1]).shape, traj_length, rewards[::-1])
#         for value in mc_target[::-1]:
#             mc_targets.append(value)
#     dataset['traj return'] = np.array(mc_targets)
#     print(dataset['traj return'])
#     return dataset

def add_return(dataset):
    """Add ``'masks'`` and ``'traj return'`` entries to ``dataset`` (mutated in place).

    ``masks`` is ``1.0 - terminals``.  A transition ends a trajectory when its
    ``terminals`` flag equals 1, when the next row's observation differs from
    this row's ``next_observations`` by more than 1e-5 (L2 norm over the
    flattened observation), or when it is the last row.  Every step is tagged
    with the undiscounted reward sum of the trajectory it belongs to.

    Assumes ``rewards`` / ``terminals`` are per-transition 1-D arrays -- TODO
    confirm for all dataset sources.

    Returns:
        the same ``dataset`` dict.
    """
    rewards = dataset['rewards']
    dataset['masks'] = 1.0 - dataset['terminals']
    num_steps = len(rewards)
    if num_steps == 0:
        dataset['traj return'] = np.zeros_like(rewards)
        return dataset

    # Vectorized trajectory-boundary detection (replaces a per-row Python loop).
    dones_float = np.zeros_like(rewards)
    if num_steps > 1:
        observations = np.asarray(dataset['observations'])
        next_observations = np.asarray(dataset['next_observations'])
        gap = (observations[1:] - next_observations[:-1]).reshape(num_steps - 1, -1)
        discontinuous = np.linalg.norm(gap, axis=1) > 1e-5
        terminal = np.ravel(dataset['terminals'])[:-1] == 1.0
        dones_float[:-1] = np.logical_or(discontinuous, terminal)
    dones_float[-1] = 1

    ends = np.flatnonzero(dones_float == 1.0)
    starts = np.concatenate(([0], ends[:-1] + 1)).astype(np.intp)
    traj_return = np.add.reduceat(np.asarray(rewards), starts)
    lengths = ends - starts + 1

    __traj_return = np.zeros_like(rewards)
    __traj_return[:] = np.repeat(traj_return, lengths)
    traj_length = int(lengths.sum())
    print(len(traj_return))
    dataset['traj return'] = __traj_return
    print(traj_length, len(dataset['rewards']), __traj_return)
    return dataset


def get_queries_from_multi(env, dataset, relabeled_dataset, num_query, len_query, reward_model, batch, query_index, balance=False, label_type=0, skip_flag=0):
    """Build (or extend) a batch of labeled preference queries.

    Samples 10000 candidate segment pairs from ``dataset`` via
    ``construct_query_dataset`` and labels them with a perfectly rational
    labeler (segment 2 is preferred iff its reward sum is strictly larger).

    - ``reward_model is None``: returns a fresh batch of the first 2 candidates.
    - otherwise: calls ``train_offline_iter.train_q_network(reward_model,
      query_index)`` and appends the candidates chosen by ``select_query`` on the
      returned agent's critic to the existing ``batch`` dict (mutated in place).

    ``num_query`` and ``balance`` are accepted for interface compatibility but
    are not used.  ``dataset`` is mutated by ``add_return``.

    Raises:
        ValueError: if ``label_type`` is not 0 (the only implemented labeler).

    Returns:
        ``(batch, relabeled_dataset)``; ``relabeled_dataset`` is replaced by the
        one returned from ``train_q_network`` in the second mode.
    """
    if label_type != 0:
        # Fail before any expensive work; previously this surfaced later as a
        # NameError on an undefined ``rational_labels``.
        raise ValueError(f"Unsupported label_type {label_type!r}; only 0 (perfectly rational) is implemented.")

    prepare_num_query = 10000
    init_query_number = 2

    dataset = add_return(dataset)

    if relabeled_dataset is None:
        trj_idx_list = np.array(new_get_trj_idx(env, dataset=dataset))
    else:
        trj_idx_list = np.array(new_get_trj_idx_reward_model(relabeled_dataset, env, dataset=dataset))
    print('*'*30)
    print('traj dataset size: ', len(trj_idx_list))
    print('*'*30)
    trj_len_list = trj_idx_list[:, 1] - trj_idx_list[:, 0] + 1
    labeler_info = np.zeros(len(trj_idx_list) - 1)
    assert max(trj_len_list) > len_query

    pref_dataset_dict = init_pref_dataset(dataset, prepare_num_query, len_query)
    (seg_reward_1, seg_reward_2, seg_obs_1, seg_obs_2, seg_next_obs_1, seg_next_obs_2,
     seq_act_1, seq_act_2, seq_timestep_1, seq_timestep_2, start_indices_1, start_indices_2,
     traj_return_1, traj_return_2) = construct_query_dataset(
        dataset, pref_dataset_dict, prepare_num_query, trj_idx_list, labeler_info,
        len_query, skip_flag, trj_len_list)

    # Perfectly rational labeler: 1 iff segment 2 has the strictly larger return.
    rational_labels = 1 * (np.sum(seg_reward_1, axis=1) < np.sum(seg_reward_2, axis=1))

    # Batch key -> full candidate array (insertion order is the batch key order).
    candidates = {
        'labels': rational_labels,
        'observations': seg_obs_1,  # for compatibility, remove "_1"
        'next_observations': seg_next_obs_1,
        'actions': seq_act_1,
        'observations_2': seg_obs_2,
        'next_observations_2': seg_next_obs_2,
        'actions_2': seq_act_2,
        'timestep_1': seq_timestep_1,
        'timestep_2': seq_timestep_2,
        'start_indices': start_indices_1,
        'start_indices_2': start_indices_2,
        'traj return_1': traj_return_1,
        'traj return_2': traj_return_2,
    }

    if reward_model is None:
        batch = {key: value[:init_query_number] for key, value in candidates.items()}
    else:
        import train_offline_iter
        agent, relabeled_dataset = train_offline_iter.train_q_network(reward_model, query_index)
        topkindex = select_query(agent.critic, seg_obs_1, seg_obs_2, seq_act_1, seq_act_2, query_index)
        for key, value in candidates.items():
            batch[key] = np.concatenate((batch[key], value[topkindex]), axis=0)
    return batch, relabeled_dataset